{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import nibabel as nib\n",
    "import numpy as np\n",
    "\n",
    "# Fix random seed for reproducibility?\n",
    "# Better to follow the advice in Keras FAQ:\n",
    "#  \"How can I obtain reproducible results using Keras during development?\"\n",
    "seed = 7\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_classes = 3\n",
    "\n",
    "patience = 1\n",
    "model_filename = 'models/iSeg2017/outrun_step_{}.h5'\n",
    "csv_filename = 'log/iSeg2017/outrun_step_{}.cvs'\n",
    "\n",
    "nb_epoch = 20\n",
    "validation_split = 0.25\n",
    "\n",
    "class_mapper = {0 : 0, 10 : 0, 150 : 1, 250 : 2}\n",
    "class_mapper_inv = {0 : 0, 1 : 10, 2 : 150, 3 : 250}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# General utils for reading and saving data\n",
    "def get_filename(set_name, case_idx, input_name, loc='datasets') :\n",
    "    pattern = '{0}/iSeg2017/iSeg-2017-{1}/subject-{2}-{3}.hdr'\n",
    "    return pattern.format(loc, set_name, case_idx, input_name)\n",
    "\n",
    "def get_set_name(case_idx) :\n",
    "    return 'Training' if case_idx < 11 else 'Testing'\n",
    "\n",
    "def read_data(case_idx, input_name, loc='datasets') :\n",
    "    set_name = get_set_name(case_idx)\n",
    "\n",
    "    image_path = get_filename(set_name, case_idx, input_name, loc)\n",
    "\n",
    "    return nib.load(image_path)\n",
    "\n",
    "def read_vol(case_idx, input_name, loc='datasets') :\n",
    "    image_data = read_data(case_idx, input_name, loc)\n",
    "\n",
    "    return image_data.get_data()[:, :, :, 0]\n",
    "\n",
    "def save_vol(segmentation, case_idx, loc='results') :\n",
    "    set_name = get_set_name(case_idx)\n",
    "    input_image_data = read_data(case_idx, 'T1')\n",
    "\n",
    "    segmentation_vol = np.empty(input_image_data.shape)\n",
    "    segmentation_vol[:144, :192, :256, 0] = segmentation\n",
    "    \n",
    "    filename = get_filename(set_name, case_idx, 'label', loc)\n",
    "    nib.save(nib.analyze.AnalyzeImage(\n",
    "        segmentation_vol.astype('uint8'), input_image_data.affine), filename)\n",
    "\n",
    "\n",
    "# Data preparation utils\n",
    "from keras.utils import np_utils\n",
    "from sklearn.feature_extraction.image import extract_patches as sk_extract_patches\n",
    "\n",
    "def extract_patches(volume, patch_shape, extraction_step) :\n",
    "    patches = sk_extract_patches(\n",
    "        volume,\n",
    "        patch_shape=patch_shape,\n",
    "        extraction_step=extraction_step)\n",
    "\n",
    "    ndim = len(volume.shape)\n",
    "    npatches = np.prod(patches.shape[:ndim])\n",
    "    return patches.reshape((npatches, ) + patch_shape)\n",
    "\n",
    "def build_set(T1_vols, T2_vols, label_vols, extraction_step=(9, 9, 9)) :\n",
    "    patch_shape = (27, 27, 27)\n",
    "    label_selector = [slice(None)] + [slice(9, 18) for i in range(3)]\n",
    "\n",
    "    # Extract patches from input volumes and ground truth\n",
    "    x = np.zeros((0, 2, 27, 27, 27))\n",
    "    y = np.zeros((0, 9 * 9 * 9, num_classes))\n",
    "    for idx in range(len(T1_vols)) :\n",
    "        y_length = len(y)\n",
    "\n",
    "        label_patches = extract_patches(label_vols[idx], patch_shape, extraction_step)\n",
    "        label_patches = label_patches[label_selector]\n",
    "\n",
    "        # Select only those who are important for processing\n",
    "        valid_idxs = np.where(np.sum(label_patches, axis=(1, 2, 3)) != 0)\n",
    "\n",
    "        # Filtering extracted patches\n",
    "        label_patches = label_patches[valid_idxs]\n",
    "\n",
    "        x = np.vstack((x, np.zeros((len(label_patches), 2, 27, 27, 27))))\n",
    "        y = np.vstack((y, np.zeros((len(label_patches), 9 * 9 * 9, num_classes))))\n",
    "\n",
    "        for i in range(len(label_patches)) :\n",
    "            y[i+y_length, :, :] = np_utils.to_categorical(label_patches[i].flatten(), num_classes)\n",
    "\n",
    "        del label_patches\n",
    "\n",
    "        # Sampling strategy: reject samples which labels are only zeros\n",
    "        T1_train = extract_patches(T1_vols[idx], patch_shape, extraction_step)\n",
    "        x[y_length:, 0, :, :, :] = T1_train[valid_idxs]\n",
    "        del T1_train\n",
    "\n",
    "        # Sampling strategy: reject samples which labels are only zeros\n",
    "        T2_train = extract_patches(T2_vols[idx], patch_shape, extraction_step)\n",
    "        x[y_length:, 1, :, :, :] = T2_train[valid_idxs]\n",
    "        del T2_train\n",
    "    return x, y\n",
    "\n",
    "# Reconstruction utils\n",
    "import itertools\n",
    "\n",
    "def generate_indexes(patch_shape, expected_shape) :\n",
    "    ndims = len(patch_shape)\n",
    "\n",
    "    poss_shape = [patch_shape[i+1] * (expected_shape[i] // patch_shape[i+1]) for i in range(ndims-1)]\n",
    "\n",
    "    idxs = [range(patch_shape[i+1], poss_shape[i] - patch_shape[i+1], patch_shape[i+1]) for i in range(ndims-1)]\n",
    "\n",
    "    return itertools.product(*idxs)\n",
    "\n",
    "def reconstruct_volume(patches, expected_shape) :\n",
    "    patch_shape = patches.shape\n",
    "\n",
    "    assert len(patch_shape) - 1 == len(expected_shape)\n",
    "\n",
    "    reconstructed_img = np.zeros(expected_shape)\n",
    "\n",
    "    for count, coord in enumerate(generate_indexes(patch_shape, expected_shape)) :\n",
    "        selection = [slice(coord[i], coord[i] + patch_shape[i+1]) for i in range(len(coord))]\n",
    "        reconstructed_img[selection] = patches[count]\n",
    "\n",
    "    return reconstructed_img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import backend as K\n",
    "from keras.layers import Activation\n",
    "from keras.layers import Input\n",
    "from keras.layers.advanced_activations import PReLU\n",
    "from keras.layers.convolutional import Conv3D\n",
    "from keras.layers.convolutional import Cropping3D\n",
    "from keras.layers.core import Permute\n",
    "from keras.layers.core import Reshape\n",
    "from keras.layers.merge import concatenate\n",
    "from keras.models import Model\n",
    "\n",
    "K.set_image_dim_ordering('th')\n",
    "\n",
    "# For understanding the architecture itself, I recommend checking the following article\n",
    "# Dolz, J. et al. 3D fully convolutional networks for subcortical segmentation in MRI :\n",
    "# A large-scale study. Neuroimage, 2017.\n",
    "def generate_model(num_classes) :\n",
    "    init_input = Input((2, 27, 27, 27))\n",
    "\n",
    "    x = Conv3D(25, kernel_size=(3, 3, 3))(init_input)\n",
    "    x = PReLU()(x)\n",
    "    x = Conv3D(25, kernel_size=(3, 3, 3))(x)\n",
    "    x = PReLU()(x)\n",
    "    x = Conv3D(25, kernel_size=(3, 3, 3))(x)\n",
    "    x = PReLU()(x)\n",
    "\n",
    "    y = Conv3D(50, kernel_size=(3, 3, 3))(x)\n",
    "    y = PReLU()(y)\n",
    "    y = Conv3D(50, kernel_size=(3, 3, 3))(y)\n",
    "    y = PReLU()(y)\n",
    "    y = Conv3D(50, kernel_size=(3, 3, 3))(y)\n",
    "    y = PReLU()(y)\n",
    "\n",
    "    z = Conv3D(75, kernel_size=(3, 3, 3))(y)\n",
    "    z = PReLU()(z)\n",
    "    z = Conv3D(75, kernel_size=(3, 3, 3))(z)\n",
    "    z = PReLU()(z)\n",
    "    z = Conv3D(75, kernel_size=(3, 3, 3))(z)\n",
    "    z = PReLU()(z)\n",
    "\n",
    "    x_crop = Cropping3D(cropping=((6, 6), (6, 6), (6, 6)))(x)\n",
    "    y_crop = Cropping3D(cropping=((3, 3), (3, 3), (3, 3)))(y)\n",
    "\n",
    "    concat = concatenate([x_crop, y_crop, z], axis=1)\n",
    "\n",
    "    fc = Conv3D(400, kernel_size=(1, 1, 1))(concat)\n",
    "    fc = PReLU()(fc)\n",
    "    fc = Conv3D(200, kernel_size=(1, 1, 1))(fc)\n",
    "    fc = PReLU()(fc)\n",
    "    fc = Conv3D(150, kernel_size=(1, 1, 1))(fc)\n",
    "    fc = PReLU()(fc)\n",
    "\n",
    "    pred = Conv3D(num_classes, kernel_size=(1, 1, 1))(fc)\n",
    "    pred = PReLU()(pred)\n",
    "    pred = Reshape((num_classes, 9 * 9 * 9))(pred)\n",
    "    pred = Permute((2, 1))(pred)\n",
    "    pred = Activation('softmax')(pred)\n",
    "\n",
    "    model = Model(inputs=init_input, outputs=pred)\n",
    "    model.compile(\n",
    "        loss='categorical_crossentropy',\n",
    "        optimizer='adam',\n",
    "        metrics=['categorical_accuracy'])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Initial segmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.1 Read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "T1_vols = np.empty((10, 144, 192, 256))\n",
    "T2_vols = np.empty((10, 144, 192, 256))\n",
    "label_vols = np.empty((10, 144, 192, 256))\n",
    "for case_idx in range(1, 11) :\n",
    "    T1_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T1')\n",
    "    T2_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T2')\n",
    "    label_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'label')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.2 Pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "## Intensity normalisation (zero mean and unit variance)\n",
    "T1_mean = T1_vols.mean()\n",
    "T1_std = T1_vols.std()\n",
    "T1_vols = (T1_vols - T1_mean) / T1_std\n",
    "T2_mean = T2_vols.mean()\n",
    "T2_std = T2_vols.std()\n",
    "T2_vols = (T2_vols - T2_mean) / T2_std\n",
    "\n",
    "# Combine labels of BG and CSF\n",
    "for class_idx in class_mapper :\n",
    "    label_vols[label_vols == class_idx] = class_mapper[class_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.3 Data preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train, y_train = build_set(T1_vols, T2_vols, label_vols, (3, 9, 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.4 Configure callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.callbacks import ModelCheckpoint\n",
    "from keras.callbacks import CSVLogger\n",
    "from keras.callbacks import EarlyStopping\n",
    "\n",
    "# Early stopping for reducing over-fitting risk\n",
    "stopper = EarlyStopping(patience=patience)\n",
    "\n",
    "# Model checkpoint to save the training results\n",
    "checkpointer = ModelCheckpoint(\n",
    "    filepath=model_filename.format(1),\n",
    "    verbose=0,\n",
    "    save_best_only=True,\n",
    "    save_weights_only=True)\n",
    "\n",
    "# CSVLogger to save the training results in a csv file\n",
    "csv_logger = CSVLogger(csv_filename.format(1), separator=';')\n",
    "\n",
    "callbacks = [checkpointer, csv_logger, stopper]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.5 Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build model\n",
    "model = generate_model(num_classes)\n",
    "\n",
    "model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    epochs=nb_epoch,\n",
    "    validation_split=validation_split,\n",
    "    verbose=2,\n",
    "    callbacks=callbacks)\n",
    "\n",
    "# freeing space\n",
    "del x_train\n",
    "del y_train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.6 Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.models import load_model\n",
    "\n",
    "# Load best model\n",
    "model = generate_model(num_classes)\n",
    "model.load_weights(model_filename.format(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for case_idx in range(11, 24) :\n",
    "    T1_test_vol = read_vol(case_idx, 'T1')[:144, :192, :256]\n",
    "    T2_test_vol = read_vol(case_idx, 'T2')[:144, :192, :256]\n",
    "    \n",
    "    x_test = np.zeros((6916, 2, 27, 27, 27))\n",
    "    x_test[:, 0, :, :, :] = extract_patches(T1_test_vol, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))\n",
    "    x_test[:, 1, :, :, :] = extract_patches(T2_test_vol, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))\n",
    "    \n",
    "    x_test[:, 0, :, :, :] = (x_test[:, 0, :, :, :] - T1_mean) / T1_std\n",
    "    x_test[:, 1, :, :, :] = (x_test[:, 1, :, :, :] - T2_mean) / T2_std\n",
    "\n",
    "    pred = model.predict(x_test, verbose=2)\n",
    "    pred_classes = np.argmax(pred, axis=2)\n",
    "    pred_classes = pred_classes.reshape((len(pred_classes), 9, 9, 9))\n",
    "    segmentation = reconstruct_volume(pred_classes, (144, 192, 256))\n",
    "    \n",
    "    csf = np.logical_and(segmentation == 0, T1_test_vol != 0)\n",
    "    segmentation[segmentation == 2] = 250\n",
    "    segmentation[segmentation == 1] = 150\n",
    "    segmentation[csf] = 10\n",
    "    \n",
    "    save_vol(segmentation, case_idx)\n",
    "    \n",
    "    print \"Finished segmentation of case # {}\".format(case_idx)\n",
    "\n",
    "print \"Done with Step 1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Pseudo-labelling step"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 Read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sure = range(0, 10)\n",
    "unsure = range(11, 23)\n",
    "\n",
    "T1_vols = np.empty((23, 144, 192, 256))\n",
    "T2_vols = np.empty((23, 144, 192, 256))\n",
    "label_vols = np.empty((23, 144, 192, 256))\n",
    "for case_idx in range(1, 24) :\n",
    "    loc = 'datasets' if case_idx < 11 else 'results'\n",
    "\n",
    "    T1_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T1')[:144, :192, :256]\n",
    "    T2_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T2')[:144, :192, :256]\n",
    "    label_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'label', loc)[:144, :192, :256]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 Pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "## Intensity normalisation (zero mean and unit variance)\n",
    "T1_mean = T1_vols.mean()\n",
    "T1_std = T1_vols.std()\n",
    "T1_vols = (T1_vols - T1_mean) / T1_std\n",
    "T2_mean = T2_vols.mean()\n",
    "T2_std = T2_vols.std()\n",
    "T2_vols = (T2_vols - T2_mean) / T2_std\n",
    "\n",
    "# Combine labels of BG and CSF\n",
    "for class_idx in class_mapper :\n",
    "    label_vols[label_vols == class_idx] = class_mapper[class_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.3 Data preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_sure, y_sure = build_set(T1_vols[sure], T2_vols[sure], label_vols[sure], (3, 9, 3))\n",
    "x_unsure, y_unsure = build_set(T1_vols[unsure], T2_vols[unsure], label_vols[unsure])\n",
    "\n",
    "x_train = np.vstack((x_sure, x_unsure))\n",
    "y_train = np.vstack((y_sure, y_unsure))\n",
    "\n",
    "del x_sure\n",
    "del x_unsure\n",
    "del y_sure\n",
    "del y_unsure"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.4 Configure callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.callbacks import ModelCheckpoint\n",
    "from keras.callbacks import CSVLogger\n",
    "from keras.callbacks import EarlyStopping\n",
    "\n",
    "# Early stopping for reducing over-fitting risk\n",
    "stopper = EarlyStopping(patience=patience)\n",
    "\n",
    "# Model checkpoint to save the training results\n",
    "checkpointer = ModelCheckpoint(\n",
    "    filepath=model_filename.format(2),\n",
    "    verbose=0,\n",
    "    save_best_only=True,\n",
    "    save_weights_only=True)\n",
    "\n",
    "# CSVLogger to save the training results in a csv file\n",
    "csv_logger = CSVLogger(csv_filename.format(2), separator=';')\n",
    "\n",
    "callbacks = [checkpointer, csv_logger, stopper]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.5 Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build model\n",
    "model = generate_model(num_classes)\n",
    "\n",
    "model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    epochs=nb_epoch,\n",
    "    validation_split=validation_split,\n",
    "    verbose=2,\n",
    "    callbacks=callbacks)\n",
    "\n",
    "# freeing space\n",
    "del x_train\n",
    "del y_train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.6 Clasification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.models import load_model\n",
    "\n",
    "# Load best model\n",
    "model = generate_model(num_classes)\n",
    "model.load_weights(model_filename.format(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for case_idx in range(11, 24) :\n",
    "    T1_test_vol = read_vol(case_idx, 'T1')[:144, :192, :256]\n",
    "    T2_test_vol = read_vol(case_idx, 'T2')[:144, :192, :256]\n",
    "    \n",
    "    x_test = np.zeros((6916, 2, 27, 27, 27))\n",
    "    x_test[:, 0, :, :, :] = extract_patches(T1_test_vol, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))\n",
    "    x_test[:, 1, :, :, :] = extract_patches(T2_test_vol, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))\n",
    "    \n",
    "    x_test[:, 0, :, :, :] = (x_test[:, 0, :, :, :] - T1_mean) / T1_std\n",
    "    x_test[:, 1, :, :, :] = (x_test[:, 1, :, :, :] - T2_mean) / T2_std\n",
    "\n",
    "    pred = model.predict(x_test, verbose=2)\n",
    "    pred_classes = np.argmax(pred, axis=2)\n",
    "    pred_classes = pred_classes.reshape((len(pred_classes), 9, 9, 9))\n",
    "    segmentation = reconstruct_volume(pred_classes, (144, 192, 256))\n",
    "    \n",
    "    csf = np.logical_and(segmentation == 0, T1_test_vol != 0)\n",
    "    segmentation[segmentation == 2] = 250\n",
    "    segmentation[segmentation == 1] = 150\n",
    "    segmentation[csf] = 10\n",
    "    \n",
    "    save_vol(segmentation, case_idx, 'refined-results')\n",
    "    \n",
    "    print \"Finished segmentation of case # {}\".format(case_idx)\n",
    "\n",
    "print \"Done with Step 2\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
